# ar_sim/continuum_action.py

import numpy as np
import sympy as sp
from ar_sim.common.fractal_fits import load_D_values, fit_fractal_curve, logistic_D_function
from ar_sim.common.kernel_builder import build_reproduction_kernel

class ContinuumAction:
    """
    Derive continuum Hamiltonian & Lagrangian from the discrete master action.

    The reproduction kernel ``M`` is built once in ``__init__``. The symbolic
    Legendre transform (matrix inversion, linear solve and simplification) is
    expensive. Its result is therefore computed lazily on first use and cached
    on the instance in ``self._legendre_cache``. If ``self.M`` is replaced or
    mutated afterwards, reset ``self._legendre_cache = None`` to force a
    recomputation.
    """

    def __init__(self,
                 n_vals: np.ndarray,
                 pivot_params: dict,
                 sigma: float = 1.0):
        """
        Args:
          n_vals: array of context-levels used to build the kernel
          pivot_params: dict with keys 'a', 'b', and optionally 'D_vals'
            (forwarded unchanged to ``build_reproduction_kernel``)
          sigma: width parameter for the kernel
        """
        self.n_vals = n_vals
        self.pivot_params = pivot_params
        self.sigma = sigma

        # Build reproduction kernel M_ij. It is indexed as M[i, j] for
        # i, j < len(n_vals) below; presumably an N x N array -- TODO confirm
        # against kernel_builder.
        self.M = build_reproduction_kernel(n_vals, pivot_params, sigma)

        # Symbolic placeholders for fields q_i, p_i
        N = len(n_vals)
        self.q = sp.symbols(f"q0:{N}")
        self.p = sp.symbols(f"p0:{N}")

        # Lazily populated (H, L) pair returned by legendre_transform().
        self._legendre_cache = None

    @staticmethod
    def _check_length(name: str, values, expected: int) -> None:
        """Raise ValueError unless ``values`` has exactly ``expected`` entries."""
        if len(values) != expected:
            raise ValueError(
                f"{name} must have length {expected}, got {len(values)}"
            )

    def _velocity_symbols(self) -> tuple:
        """Return the velocity symbols dq0..dq{N-1}, one per field in ``self.q``."""
        # Sympy symbols with equal names compare equal, so every call yields
        # symbols interchangeable with those used inside legendre_transform().
        return sp.symbols(f"dq0:{len(self.q)}")

    def discrete_action(self):
        """
        Construct the discrete action S_disc = 1/2 sum_{i,j} q_i M_{ij} q_j
        (up to overall factors).

        Returns:
          A simplified Sympy expression in the symbols ``self.q``.
        """
        n = len(self.q)
        S = sum(
            sp.Rational(1, 2) * self.q[i] * self.M[i, j] * self.q[j]
            for i in range(n)
            for j in range(n)
        )
        return sp.simplify(S)

    def legendre_transform(self):
        """
        Perform a symbolic Legendre transform to obtain
        H(p,q) = sum_i p_i * dq_i/dt - L, assuming p_i = dL/d(dq_i), for the
        Lagrangian L = 1/2 dq^T M^{-1} dq - 1/2 q^T M q.

        The result is cached on the instance after the first call, because
        the inverse, solve and simplify steps are costly and the returned
        sympy expressions are immutable.

        Returns:
          (H_expr, L_expr): Sympy expressions. H is in ``self.p`` and
          ``self.q``. L is in the velocity symbols dq0..dq{N-1} and ``self.q``.
        """
        cached = getattr(self, "_legendre_cache", None)
        if cached is not None:
            return cached

        n = len(self.q)
        dq = self._velocity_symbols()

        # M must be invertible for the kinetic term M^{-1} to exist.
        M_inv = sp.Matrix(self.M).inv()
        kinetic = sum(dq[i] * M_inv[i, j] * dq[j]
                      for i in range(n) for j in range(n))
        potential = sum(self.q[i] * self.M[i, j] * self.q[j]
                        for i in range(n) for j in range(n))
        L = sp.simplify(sp.Rational(1, 2) * kinetic - sp.Rational(1, 2) * potential)

        # Conjugate momenta p_i = dL/d(dq_i)
        p_defs = [sp.diff(L, dq_i) for dq_i in dq]

        # Solve the (linear) momentum definitions for dq in terms of p
        sols = sp.solve([p_defs[i] - self.p[i] for i in range(n)], dq)
        L_sub = L.subs(sols)
        H = sum(self.p[i] * sols[dq[i]] for i in range(n)) - L_sub

        result = (sp.simplify(H), sp.simplify(L))
        self._legendre_cache = result
        return result

    def numeric_hamiltonian(self, q_vals: np.ndarray, p_vals: np.ndarray):
        """
        Evaluate the symbolic Hamiltonian H(p,q) at numeric arrays q_vals, p_vals.

        Args:
          q_vals: one value per field symbol in ``self.q``
          p_vals: one value per momentum symbol in ``self.p``

        Returns:
          H as a Python float.

        Raises:
          ValueError: if the array lengths do not match the number of fields.
        """
        # Validate before triggering the (possibly uncached) symbolic work.
        self._check_length("q_vals", q_vals, len(self.q))
        self._check_length("p_vals", p_vals, len(self.p))
        H_sym, _ = self.legendre_transform()
        subs = {**dict(zip(self.q, q_vals)), **dict(zip(self.p, p_vals))}
        return float(H_sym.subs(subs))

    def numeric_lagrangian(self, q_vals: np.ndarray, dq_vals: np.ndarray):
        """
        Evaluate the symbolic Lagrangian L(dq,q) at numeric arrays dq_vals, q_vals.

        Args:
          q_vals: one value per field symbol in ``self.q``
          dq_vals: one velocity value per field symbol in ``self.q``

        Returns:
          L as a Python float.

        Raises:
          ValueError: if the array lengths do not match the number of fields.
        """
        n = len(self.q)
        self._check_length("q_vals", q_vals, n)
        self._check_length("dq_vals", dq_vals, n)
        _, L_sym = self.legendre_transform()
        dq_syms = self._velocity_symbols()
        subs = {**dict(zip(self.q, q_vals)), **dict(zip(dq_syms, dq_vals))}
        return float(L_sym.subs(subs))
